/*
 * Copyright (c) 2020-2023, NVIDIA CORPORATION.  All rights reserved.
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#pragma once

#include <cuda_runtime.h>

////////////////////////////////////////////////////////////////////////////////////////////////////

// The attention mask types.
enum class TrtllmGenAttentionMaskType {
  // Dense mask.
  Dense = 0,
  // Causal mask.
  Causal,
  // Sliding window or chunked causal mask.
  SlidingOrChunkedCausal,
  // Custom mask.
  Custom
};

////////////////////////////////////////////////////////////////////////////////////////////////////

// Helper functions to check the mask type.

#define ATTENTION_MASK_TYPE_FUNCTION(MaskType)                          \
  inline bool is##MaskType##Mask(TrtllmGenAttentionMaskType maskType) { \
    return (maskType == TrtllmGenAttentionMaskType::MaskType);          \
  }

ATTENTION_MASK_TYPE_FUNCTION(Dense)
ATTENTION_MASK_TYPE_FUNCTION(Causal)
ATTENTION_MASK_TYPE_FUNCTION(SlidingOrChunkedCausal)
ATTENTION_MASK_TYPE_FUNCTION(Custom)

#undef ATTENTION_MASK_TYPE_FUNCTION

////////////////////////////////////////////////////////////////////////////////////////////////////

enum class FmhaKernelType {
  // The context-phase kernels.
  Context = 0,
  // Choose the best generation kernel based on the heuristic:
  // use SwapsMmaAbForGeneration kernels when numHeadsQPerKv <= 16, otherwise
  // KeepsMmaAbForGeneration.
  Generation = 1,
  // Swap tensor A and tensor B of Mma, which only supports numHeadsQPerKv <= 16.
  SwapsMmaAbForGeneration,
  // Keep tensor A and tensor B of Mma.
  KeepsMmaAbForGeneration,
  // Speculative decoding (Medusa and Eagle) generation-phase attention kernels, where seqLenQ > 1.
  SpecDecodingGeneration
};

////////////////////////////////////////////////////////////////////////////////////////////////////

// Helper functions to check the fmha kernel type.

#define FMHA_KERNEL_TYPE_FUNCTION(KernelType)                     \
  inline bool is##KernelType##Kernel(FmhaKernelType kernelType) { \
    return (kernelType == FmhaKernelType::KernelType);            \
  }

FMHA_KERNEL_TYPE_FUNCTION(Context)
FMHA_KERNEL_TYPE_FUNCTION(Generation)
FMHA_KERNEL_TYPE_FUNCTION(SwapsMmaAbForGeneration)
FMHA_KERNEL_TYPE_FUNCTION(KeepsMmaAbForGeneration)
FMHA_KERNEL_TYPE_FUNCTION(SpecDecodingGeneration)

#undef QKV_LAYOUT_FUNCTION

////////////////////////////////////////////////////////////////////////////////////////////////////

// Note that (batchSize, seqLen) dimensions will be packed as sumOfSeqLens without paddings for
// variable sequence lengths.
enum class QkvLayout {
  // SeparateQkv: separate Q, K and V buffers.
  // Each has the shape: [batchSize, seqLen, numHeads, headDim].
  SeparateQkv = 0,
  // PackedQkv: single buffer for Q, K and V.
  // Shape: [batchSize, seqLen, numHeadsQ + 2*numHeadsKv, headDim].
  PackedQkv,
  // Paged buffer for K and V. Its shape is [batchSize, 2, maxNumPagesPerSeq]. The 2 corresponds to
  // K
  // and V. That buffer stores the logical page index of the paged-KV memory pool. Each "page" of
  // that
  // pool is a contiguous buffer of shape [numHeadsKv, pageSize, headDim].
  PagedKv,
  // ContiguousKv:
  // Contiguous buffer for Q with shape [batchSize, seqLen, numHeads, headDim].
  // Contiguous buffer for Kv with shape [batchSize, seqLen, 2 * numHeads, headDim].
  ContiguousKv,
};

// Helper functions to check the QkvLayout type.

#define QKV_LAYOUT_FUNCTION(LayoutType) \
  inline bool is##LayoutType(QkvLayout qkvLayout) { return (qkvLayout == QkvLayout::LayoutType); }

QKV_LAYOUT_FUNCTION(SeparateQkv)
QKV_LAYOUT_FUNCTION(PackedQkv)
QKV_LAYOUT_FUNCTION(PagedKv)
QKV_LAYOUT_FUNCTION(ContiguousKv)

#undef QKV_LAYOUT_FUNCTION

////////////////////////////////////////////////////////////////////////////////////////////////////

enum class TileScheduler {
  // Static scheduler (Non-persistent).
  Static = 0,
  // Persistent scheduler.
  Persistent
};

////////////////////////////////////////////////////////////////////////////////////////////////////
enum class MultiCtasKvMode {
  // No multiCtasKvMode.
  Disabled = 0,
  // Do the reduction through the global memory and atomic counters.
  GmemReduction,
  // Do the reduction through the CGA remote shared memory.
  CgaSmemReduction
};

// Helper function to check if the multiCtasKv is enabled.
inline bool isMultiCtasKvEnabled(MultiCtasKvMode multiCtasKvMode) {
  return multiCtasKvMode != MultiCtasKvMode::Disabled;
}

// Helper function to check the multiCtasKvMode type.

#define MULTI_CTAS_KV_MODE_FUNCTION(Type)                 \
  inline bool is##Type(MultiCtasKvMode multiCtasKvMode) { \
    return (multiCtasKvMode == MultiCtasKvMode::Type);    \
  }

MULTI_CTAS_KV_MODE_FUNCTION(Disabled)
MULTI_CTAS_KV_MODE_FUNCTION(GmemReduction)
MULTI_CTAS_KV_MODE_FUNCTION(CgaSmemReduction)

#undef MULTI_CTAS_KV_MODE_FUNCTION

////////////////////////////////////////////////////////////////////////////////////////////////////
struct TllmGenFmhaRunnerParams {
  // Input layout.
  QkvLayout mQkvLayout;
  // Attention mask type.
  TrtllmGenAttentionMaskType mMaskType;
  // The kernel type.
  FmhaKernelType mKernelType;
  // The tile scheduler.
  TileScheduler mTileScheduler;
  // The multiCtasKvMode (i.e. multiBlockMode).
  bool mMultiCtasKvMode;

  // Input QKV buffers.
  void const* qPtr;
  void const* kPtr;
  void const* vPtr;
  // Packed KV buffer
  void const* kvPtr;
  // Packed QKV buffer
  void const* qkvPtr;
  // The scaling factor pointer of K.
  void const* kSfBasePtr;
  // The scaling factor pointer of V.
  void const* vSfBasePtr;
  // The custom mask ptr.
  uint32_t const* customMaskPtr;
  // The packed custom mask's offsets of each sequence.
  int64_t const* customMaskOffsetsPtr;
  // The first sparseMask offsets in the Kv sequence dimension.
  int32_t const* firstSparseMaskOffsetsKvPtr;
  // The counter for the multiCtasKv mode.
  int32_t* multiCtasKvCounterPtr;
  // The sequence length buffer for K/V.
  int const* seqLensKvPtr;
  // The cumulative sequence length buffer for Q and K/V
  int const* cumSeqLensQPtr;
  int const* cumSeqLensKvPtr;
  // The kv page idx
  int const* kvPageIdxPtr;
  bool useGmemScale;
  // The device output scale for FP8 quantization.
  float const* outputScalePtr;
  float outputScale;
  // The device scaling factor for softmax (multiplied by log2 to use faster exp2)
  float const* scaleSoftmaxLog2Ptr;
  float scaleSoftmaxLog2;
  // The device scale for KV scaling factor.
  float const* kvSfScalePtr;
  // The device scale for O scaling factor.
  float const* oSfScalePtr;
  // The scratch space for each CtaKv when the multiCtasKv mode is enabled.
  // PartialO, partialMax and partialSum will be stored to the scratch space.
  void* multiCtasKvScratchPtr;
  // The softmax stats buffer.
  // The softmax max/sum values will be stored to the buffer if it is not nullptr.
  float2* softmaxStatsPtr;
  // The output buffer.
  void* oPtr;
  // The output scaling factor buffer.
  void* oSfPtr;

  // Head dimension for Q and K.
  int mHeadDimQk;
  // Head dimension for V.
  int mHeadDimV;
  // Number of heads for Q and K/V.
  int mNumHeadsQ, mNumHeadsKv, mNumHeadsQPerKv;
  // The batch size.
  int mBatchSize;
  // The max sequence length in the contiguous Kv cache.
  int mMaxSeqLenCacheKv;
  // The max q sequence length.
  int mMaxSeqLenQ;
  // The max kv sequence length.
  int mMaxSeqLenKv;
  // The attention window size for sliding window attention (sliding-window-attention is enabled
  // when seqLenKv > mAttentionWindowSize).
  int mAttentionWindowSize;
  // The chunked attention size (chunked-context is enabled when seqLenKv > mChunkedAttentionSize).
  int mChunkedAttentionSize;
  // The sum of sequence lengths for Q and K/V. (Only used when mSupportsVarSeqLens = true)
  int mSumOfSeqLensQ;
  int mSumOfSeqLensKv;
  // The maximum number of pages per sequence in the paged-kv buffer.
  int mMaxNumPagesPerSeqKv;
  // The number of tokens per pageKv.
  int mNumTokensPerPage;
  // The number of pages in memory pool.
  int mNumPagesInMemPool;
  // The number of multiProcessor for the GPU.
  int mMultiProcessorCount;
  // Scaling factor for Q.
  float mScaleQ;
  // The start token index in SF tensor. Used for FP4 SF offset calculation in generation phase
  // kernel when inflight batching is enabled.
  int mSfStartTokenIdx;

  // The SF scale for Kv.
  float mScaleSfKv;
  // The cuda stream.
  cudaStream_t stream;

  // set the attention mask type
  TllmGenFmhaRunnerParams& setAttentionMaskType(std::int8_t maskType) {
    // maskType is the enum of tensorrt_llm::kernels::ContextAttentionMaskType
    // convert ContextAttentionMaskType to TrtllmGenAttentionMaskType
    switch (maskType) {
      case 0:  // tensorrt_llm::kernels::ContextAttentionMaskType::PADDING
        mMaskType = TrtllmGenAttentionMaskType::Dense;
        break;
      case 1:  // tensorrt_llm::kernels::ContextAttentionMaskType::CAUSAL
        mMaskType = TrtllmGenAttentionMaskType::Causal;
        break;
      case 2:  // tensorrt_llm::kernels::ContextAttentionMaskType::SLIDING_OR_CHUNKED_CAUSAL
        mMaskType = TrtllmGenAttentionMaskType::SlidingOrChunkedCausal;
        break;
      case 3:  // tensorrt_llm::kernels::ContextAttentionMaskType::CUSTOM_MASK
        mMaskType = TrtllmGenAttentionMaskType::Custom;
        break;
      default:
        // TLLM_THROW("ContextAttentionMaskType %d cannot be mapped to TrtllmGenAttentionMaskType",
        //     static_cast<int>(maskType));
        printf("ContextAttentionMaskType %d cannot be mapped to TrtllmGenAttentionMaskType",
               static_cast<int>(maskType));
    }
    return *this;
  }
};

////////////////////////////////////////////////////////////////////////////////////////////////////

// Parameters that might be updated when selecting kernels.

struct TllmGenSelectKernelParams {
  // The FMHA kernel type.
  FmhaKernelType mKernelType;
  // The headDimV per CTA, which is only used by MLA generation kernels currently.
  int mHeadDimPerCtaV;
  // The multiCtasKvMode.
  MultiCtasKvMode mMultiCtasKvMode;
  // Force using GmemRedution for the multiCtasKvMode.
  bool mForceGmemReduction;
  // The mask type.
  TrtllmGenAttentionMaskType mMaskType;
  // Reuse smemK for V or not (only work with MLA generation kernels).
  bool mReuseSmemKForV;
  // Do we need to select a new kernel as the parameters have been updated.
  bool mSelectNewKernel;
  // The tile scheduler.
  TileScheduler mTileScheduler;
  // The tile size for Kv.
  int mTileSizeKv;
  // Use 2 CTA MMA or not.
  bool mUses2CtaMma;

  // The constructor.
  TllmGenSelectKernelParams(TllmGenFmhaRunnerParams params)
      : mKernelType(params.mKernelType),
        mHeadDimPerCtaV(params.mHeadDimV)
        // Note the CgaSmemReduction will be enabled based on the heuristic.
        ,
        mMultiCtasKvMode(params.mMultiCtasKvMode ? MultiCtasKvMode::GmemReduction
                                                 : MultiCtasKvMode::Disabled),
        mForceGmemReduction(false),
        mMaskType(params.mMaskType),
        mReuseSmemKForV(false),
        mSelectNewKernel(false),
        mTileScheduler(params.mTileScheduler),
        mTileSizeKv(128),
        mUses2CtaMma(false) {};
};
